
import os
import cv2
import shutil
import json
import numpy as np


# Root directory of the lane dataset. read_tu_simple() resolves both the label
# JSON file and every record's ``raw_file`` image path relative to it.
base_path = r"C:\dataset\A2RL\ws_lane_wg_240413"


def read_tu_simple():
    """Render the lane annotations of every labelled image and display them.

    The label file ``label_data_Yas_Grey_240406.json`` under ``base_path`` is
    read as JSON Lines (one JSON object per line). Each record must provide:

    * ``raw_file``  - image path, joined onto ``base_path``;
    * ``h_samples`` - y coordinates of the horizontal sample rows;
    * ``lanes``     - one list per lane holding the x coordinate at every
      sample row; values that are not ``> 0`` are treated as "no lane point".

    For every record two single-channel ``uint8`` masks with the image's
    height and width are drawn: a binary mask (lane pixels = 255) and an
    instance mask (lane ``j`` = ``40 + 50 * j``). Consecutive valid samples of
    a lane are joined by 1-pixel-wide line segments. The two masks and the
    image are shown in OpenCV windows, followed by ``cv2.waitKey(100)``.

    Blank lines in the label file are skipped, and records whose image cannot
    be loaded are skipped with a printed message.

    Returns:
        None.

    Raises:
        OSError: if the label file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks one of the keys listed above.
    """
    label_path = os.path.join(base_path, 'label_data_Yas_Grey_240406.json')

    # The context manager guarantees the label file is closed, even when a
    # record raises half-way through.
    with open(label_path, 'r', encoding='utf-8') as label_file:
        for index, line in enumerate(label_file):
            if not line.strip():
                continue  # tolerate blank lines in the label file

            data = json.loads(line)
            image_path = os.path.join(base_path, data['raw_file'])
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread reports failure by returning None, not by raising.
                print("skip record", index, "- cannot read image:", image_path)
                continue

            binary_image = np.zeros((image.shape[0], image.shape[1], 1), np.uint8)
            instance_image = binary_image.copy()

            arr_width = data['lanes']
            arr_height = data['h_samples']
            width_num = len(arr_width)
            height_num = len(arr_height)
            if width_num == 2:
                print(index)

            # Rows outermost, lanes innermost: keeps the original drawing
            # order, which decides who wins where two lanes overlap.
            for i in range(height_num):  # one pass per horizontal sample row
                y = int(arr_height[i])
                has_prev = i > 0
                for j in range(width_num):
                    # Grey level identifying lane j in the instance mask.
                    # NOTE(review): from the sixth lane on (j >= 5) this is
                    # >= 290 and no longer fits in uint8 - confirm the label
                    # files never hold more than five lanes.
                    lane_hist = 40 + 50 * j

                    cur_x = arr_width[j][i]
                    # The first row has no predecessor. Looking at index
                    # i - 1 == -1 there would wrongly consult the lane's LAST
                    # sample, so the predecessor check only applies for i > 0.
                    if cur_x > 0 and (not has_prev or arr_width[j][i - 1] > 0):
                        x = int(cur_x)
                        binary_image[y, x] = 255
                        instance_image[y, x] = lane_hist
                        if has_prev:
                            prev_point = (int(arr_width[j][i - 1]), int(arr_height[i - 1]))
                            cur_point = (x, y)
                            cv2.line(
                                binary_image,
                                prev_point,
                                cur_point,
                                (255, 255, 255),
                                1
                            )
                            cv2.line(
                                instance_image,
                                prev_point,
                                cur_point,
                                (lane_hist, lane_hist, lane_hist),
                                1
                            )

            cv2.imshow("binary_image", binary_image)
            cv2.imshow("instance_image", instance_image)
            cv2.imshow("image", image)
            cv2.waitKey(100)


if __name__ == '__main__':
    # The JSON-label preview is currently switched off:
    # read_tu_simple()

    # Show one segmentation label next to its source frame. The label ids
    # 2, 3 and 4 are remapped to brighter grey levels; every other pixel value
    # is left untouched.
    label_to_grey = ((2, 50), (3, 100), (4, 150))

    mask = cv2.imread(r'C:\dataset\archive\TUSimple\train_set\seg_label\0313-1\60\20.png')
    for label, grey in label_to_grey:
        mask[mask == label] = grey

    img = cv2.imread(r'C:\dataset\archive\TUSimple\train_set\clips\0313-1\60\20.jpg')

    cv2.imshow("img", img)
    cv2.imshow("mask", mask)
    cv2.waitKey()  # no delay argument: wait for a key press
